
import time
import numpy as np
from datetime import datetime
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
# Open the RyokoAI/ShareGPT52K dataset in streaming mode (streaming=True);
# a sample of its 'train' split is taken further down in this script.
sharegpt_ds = load_dataset("RyokoAI/ShareGPT52K", streaming=True)


def generate_arrival_list(conversations):
    """
    Build a time-ordered list of per-turn arrival records from timestamped conversations.

    Each element of ``conversations`` must be subscriptable with a
    'conversation' key holding a list of messages. Messages are read in
    pairs: even positions are treated as the user prompt and odd positions as
    the assistant reply. Every message needs a 'content' entry, and the
    assistant message needs a 'timestamp' object exposing ``.timestamp()``.
    A trailing message without a partner is ignored.

    Token counts are computed with the module-level ``tokenizer`` (anything
    with an ``encode(text)`` method), which must be defined before this
    function is called.

    Each complete (user, assistant) pair yields one record with:
    - conv_idx: index of the conversation within ``conversations``
    - timestamp: value of the assistant message's ``timestamp.timestamp()``
    - history_length: tokens of all previous turns (Q+A) plus the current Q
    - response_length: tokens of the response (A) for the current turn
    - next_arrival_time: timestamp of the next assistant message in the same
      conversation, or float('inf') if there is none
    - next_prompt_length: tokens of the next user prompt, or -1 if none

    Returns:
    - (records sorted by 'timestamp', list of user prompt token counts,
       list of model response token counts)
    """
    print("Generating arrival list...")

    # Diagnostic counters
    total_input_convs = len(conversations)

    conversation_data = []
    user_prompt_lengths = []
    model_response_lengths = []
    print(f"Processing {total_input_convs} conversations...")
    for conv_idx, row in tqdm(enumerate(conversations), total=total_input_convs):
        conversation = row['conversation']

        # Running token count of everything said so far in this conversation
        history_tokens = 0

        # Stepping to len - 1 visits only complete (user, assistant) pairs,
        # which also makes empty conversations a no-op.
        for i in range(0, len(conversation) - 1, 2):
            user_message = conversation[i]
            assistant_message = conversation[i + 1]

            # Tokenize each message exactly once
            question_tokens = len(tokenizer.encode(user_message['content']))
            response_tokens = len(tokenizer.encode(assistant_message['content']))
            user_prompt_lengths.append(question_tokens)
            model_response_lengths.append(response_tokens)

            # Current history is all previous turns plus the current Q
            history_tokens += question_tokens

            # Look ahead to the next (user, assistant) pair, if there is one
            try:
                next_arrival_time = conversation[i + 3]['timestamp'].timestamp()
                next_prompt_length = len(tokenizer.encode(conversation[i + 2]['content']))
            except (IndexError, KeyError, AttributeError, TypeError):
                # No usable next turn: inf sorts after every real arrival time
                next_arrival_time = float('inf')
                next_prompt_length = -1

            try:
                arrival_timestamp = assistant_message["timestamp"].timestamp()
            except (KeyError, AttributeError, TypeError):
                # Missing or unusable timestamp: report it and emit no record
                print(f"Error processing conversation {conv_idx}")
            else:
                conversation_data.append({
                    "conv_idx": conv_idx,
                    "timestamp": arrival_timestamp,
                    "history_length": history_tokens,
                    "response_length": response_tokens,
                    "next_arrival_time": next_arrival_time,
                    "next_prompt_length": next_prompt_length
                })

            # Always fold the response into the history, even when the record
            # was skipped, so later turns still report their full context.
            history_tokens += response_tokens

    # Print diagnostics
    print(f"Input conversations: {total_input_convs}")
    conversation_data.sort(key=lambda x: x["timestamp"])
    return conversation_data, user_prompt_lengths, model_response_lengths

def add_timestamps_to_sharegpt(sharegpt_data, lambda_conv=1.0, lambda_turn=10.0,
                               *, rng=None, base_time=None):
    """
    Add synthetic timestamps to ShareGPT conversations based on theoretical model.
    Only adds timestamps to assistant messages.

    Conversation start times are the cumulative sum of exponential
    interarrival times (mean 1/lambda_conv) offset by ``base_time``. Within a
    conversation, the first assistant message is stamped with the start time
    and each later one is shifted by a further exponential draw
    (mean 1/lambda_turn).

    Entries without a non-empty 'conversations' list, or with an odd number
    of turns, are skipped. Turns whose 'from' is neither 'human' nor 'gpt'
    are dropped.

    Parameters:
    - sharegpt_data: Sized sequence of ShareGPT entries (``len()`` is used),
      each with a 'conversations' list of turns carrying 'from' and 'value'
    - lambda_conv: Rate parameter for conversation interarrival times (new conversations); must be > 0
    - lambda_turn: Rate parameter for turn interarrival times (within conversations); must be > 0
    - rng: Optional object with an ``exponential(scale=..., size=...)`` method,
      e.g. ``np.random.default_rng(seed)``, for reproducible sampling.
      Defaults to the global ``np.random`` state.
    - base_time: Optional Unix timestamp for the start of the simulation.
      Defaults to ``time.time()``.

    Returns:
    - List of conversations in Wildchat format: dicts with a 'conversation'
      list of {'role', 'content', 'timestamp'} messages, where 'timestamp' is
      a datetime for assistant messages and None for user messages

    Raises:
    - ValueError: if lambda_conv or lambda_turn is not strictly positive
    """
    # Fail early with a clear message instead of a ZeroDivisionError below
    if lambda_conv <= 0 or lambda_turn <= 0:
        raise ValueError(
            f"lambda_conv and lambda_turn must be positive, got {lambda_conv} and {lambda_turn}"
        )

    print(f"Adding timestamps with λ_conv={lambda_conv}, λ_turn={lambda_turn}")
    print(f"Ratio λ_conv/λ_turn = {lambda_conv/lambda_turn:.5f}")

    # The np.random module and a Generator expose the same exponential() call
    sampler = np.random if rng is None else rng

    # Start time for the simulation (Unix timestamp)
    if base_time is None:
        base_time = time.time()

    # One interarrival draw per input entry, including entries skipped below
    conv_interarrivals = sampler.exponential(scale=1/lambda_conv, size=len(sharegpt_data))

    # Calculate start times for each conversation
    conv_start_times = np.cumsum(conv_interarrivals) + base_time

    # List to store conversations in Wildchat format
    wildchat_format_data = []

    print(f"Processing {len(sharegpt_data)} conversations...")

    for entry, start_time in zip(tqdm(sharegpt_data), conv_start_times):
        if 'conversations' not in entry or not entry['conversations']:
            continue
        turns = entry['conversations']
        if len(turns) % 2 != 0:
            continue

        wildchat_entry = {'conversation': []}
        curr_time = start_time
        last_index = len(turns) - 1

        for j, turn in enumerate(turns):
            sender = turn.get('from')
            if sender not in ('human', 'gpt'):
                continue

            content = turn.get('value', '')

            if sender == 'human':
                # No timestamp for user messages
                role = 'user'
                timestamp = None
            else:
                role = 'assistant'
                timestamp = datetime.fromtimestamp(curr_time)

                # Advance the clock unless this was the final turn
                if j < last_index:
                    curr_time += sampler.exponential(scale=1/lambda_turn)

            wildchat_entry['conversation'].append({
                'role': role,
                'content': content,
                'timestamp': timestamp
            })

        # Add to list if valid turns exist
        if wildchat_entry['conversation']:
            wildchat_format_data.append(wildchat_entry)

    print(f"Converted {len(wildchat_format_data)} conversations to Wildchat format with timestamps")
    print("Timestamps added only to assistant messages")

    return wildchat_format_data

# Materialize a small sample of the streamed training split.
sharegpt_data = list(sharegpt_ds['train'].take(200))  # Adjust sample size as needed

# generate_arrival_list() reads this module-level tokenizer, so it must be
# defined before that function is called.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

sharegpt_data_easy_w_timestamps = add_timestamps_to_sharegpt(
    sharegpt_data,
    lambda_conv=1,  # Slow arrival of new conversations
    lambda_turn=3.0  # Fast turns within conversations
)

# Process with the existing function
arrival_list_ShareGPT, user_prompt_lengths_ShareGPT, model_response_lengths_WildChat_ShareGPT = generate_arrival_list(sharegpt_data_easy_w_timestamps)


# The arrival list is a list of dicts; save it to a JSON file.
import json
import os

# open(..., 'w') creates the file but not its parent directory, so make sure
# the output directory exists first.
os.makedirs('./saved_data', exist_ok=True)

# NOTE: records without a next turn hold float('inf'), which json.dump writes
# as the non-standard token `Infinity` (Python's json can read it back).
with open('./saved_data/arrival_list_ShareGPT.json', 'w') as f:
    json.dump(arrival_list_ShareGPT, f)

# with open('./saved_data/user_prompt_lengths_ShareGPT_lambda_conv_1_lambda_turn_3.json', 'w') as f:
#     json.dump(user_prompt_lengths_ShareGPT, f)

# with open('./saved_data/model_response_lengths_ShareGPT_lambda_conv_1_lambda_turn_3.json', 'w') as f:
#     json.dump(model_response_lengths_WildChat_ShareGPT, f)
    

# with open('./saved_data/model_response_lengths_WildChat_ShareGPT.json', 'w') as f:


    

